#include "mpi.h"
#include <iostream>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
using namespace std;

/*
extern double drand48();
*/

/* Pipeline version of the algorithm... */
/* we really need the velocities as well... */
/* One particle: a position and a mass, stored as four consecutive doubles. */
typedef struct {
  double x;    /* x coordinate */
  double y;    /* y coordinate */
  double z;    /* z coordinate */
  double mass; /* particle mass */
} Particle;
/* We use leapfrog for the time integration ... */
/* Per-particle integration state: the previous position and the force
   accumulated for the current step. */
typedef struct {
  double xold; /* previous x coordinate */
  double yold; /* previous y coordinate */
  double zold; /* previous z coordinate */
  double fx;   /* accumulated force, x component */
  double fy;   /* accumulated force, y component */
  double fz;   /* accumulated force, z component */
} ParticleV;

/* Fill the first npart entries of both arrays with their starting values. */
void InitParticles(Particle particles[], ParticleV pv[], int npart);

/* Add the forces exerted by `others` on `myparticles` into pv; returns the
   largest per-particle stability estimate that was seen. */
double ComputeForces(Particle myparticles[], Particle others[], ParticleV pv[],
                     int npart);

/* Advance the particles by one time step and pick the next step size
   collectively over commring. */
double ComputeNewPos(Particle particles[], ParticleV pv[], int npart,
                     double max_f, MPI_Comm commring);

/* Capacity of the particle arrays declared in main; main also rejects runs
   whose total particle count exceeds this value. */
#define MAX_PARTICLES 4000
/* Capacity of the per-process counts/displs arrays declared in main. */
#define MAX_P 128
int main(int argc, char *argv[]) {
  Particle particles[MAX_PARTICLES]; /* Particles on ALL nodes */
  ParticleV pv[MAX_PARTICLES];       /* Particle velocity */
  Particle sendbuf[MAX_PARTICLES],   /* Pipeline buffers */
      recvbuf[MAX_PARTICLES];
  MPI_Request request[2];
  int counts[MAX_P],                   /* Number on each processor */
      displs[MAX_P];                   /* Offsets into particles */
  int rank, size, npart, i, j, offset; /* location of local particles */
  int totpart,                         /* total number of particles */
      cnt;                             /* number of times in loop */
  MPI_Datatype particletype;
  double sim_t; /* Simulation time */
  double time;  /* Computation time */
  int pipe, left, right, periodic;
  MPI_Comm commring;
  MPI_Status statuses[2];

  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  /* Get the best ring in the topology */
  periodic = 1;
  MPI_Cart_create(MPI_COMM_WORLD, 1, &size, &periodic, 1, &commring);
  MPI_Cart_shift(commring, 0, 1, &left, &right);

  // ==========================================
  MPI_Barrier(commring);
  if (rank == 2) {
    cout << left << "  " << right << endl;
    MPI_Abort(commring, 4);
  }
  MPI_Barrier(commring);
  // ==========================================

  /* Everyone COULD have a different size ... */
  if (argc < 2) {
    fprintf(stderr, "Usage: %s n\n", argv[0]);
    MPI_Abort(MPI_COMM_WORLD, 1);
  }
  npart = atoi(argv[1]) / size;

  if (npart * size > MAX_PARTICLES) {
    fprintf(stderr, "%d is too many; max is %d\n", npart * size, MAX_PARTICLES);
    MPI_Abort(MPI_COMM_WORLD, 1);
  }

  MPI_Type_contiguous(4, MPI_DOUBLE, &particletype);
  MPI_Type_commit(&particletype);

  /* Get the sizes and displacements */
  MPI_Allgather(&npart, 1, MPI_INT, counts, 1, MPI_INT, commring);
  displs[0] = 0;
  for (i = 1; i < size; i++)
    displs[i] = displs[i - 1] + counts[i - 1];
  totpart = displs[size - 1] + counts[size - 1];

  /* Generate the initial values */
  InitParticles(particles, pv, npart);
  offset = displs[rank];
  cnt = 10;

  time = MPI_Wtime();
  sim_t = 0.0;
  while (cnt--) {
    double max_f, max_f_seg;

    /* Load the initial sendbuffer */
    memcpy(sendbuf, particles, npart * sizeof(Particle));
    max_f = 0.0;
    for (pipe = 0; pipe < size; pipe++) {
      if (pipe != size - 1) {
        MPI_Isend(sendbuf, npart, particletype, right, pipe, commring,
                  &request[0]);
        MPI_Irecv(recvbuf, npart, particletype, left, pipe, commring,
                  &request[1]);
      }
      /* Compute forces (2D only) */
      max_f_seg = ComputeForces(particles, sendbuf, pv, npart);
      if (max_f_seg > max_f)
        max_f = max_f_seg;
      /* Push pipe */
      if (pipe != size - 1)
        MPI_Waitall(2, request, statuses);
      memcpy(sendbuf, recvbuf, counts[pipe] * sizeof(Particle));
    }
    /* Once we have the forces, we compute the changes in position */
    sim_t += ComputeNewPos(particles, pv, npart, max_f, commring);

    /* We could do graphics here (move particles on the display) */
  }
  time = MPI_Wtime() - time;
  if (rank == 0) {
    printf("Computed %d particles in %f seconds\n", totpart, time);
  }
  MPI_Finalize();
  return 0;
}

/*
 * Initialise the first npart entries of particles[] and pv[].
 *
 * Each coordinate is drawn from drand48() (x, then y, then z), the mass is
 * set to 1.0, the "old" position is set equal to the new one, and the force
 * accumulators are cleared.
 */
void InitParticles(Particle particles[], ParticleV pv[], int npart) {
  for (int idx = 0; idx < npart; ++idx) {
    Particle &body = particles[idx];
    ParticleV &state = pv[idx];

    /* Draw order matters for reproducibility: x, y, z. */
    body.x = drand48();
    body.y = drand48();
    body.z = drand48();
    body.mass = 1.0;

    /* No previous step yet, so the old position is the current one. */
    state.xold = body.x;
    state.yold = body.y;
    state.zold = body.z;

    state.fx = 0;
    state.fy = 0;
    state.fz = 0;
  }
}

/*
 * Accumulate into pv[].fx / pv[].fy the force that the npart particles in
 * others[] exert on each of the npart particles in myparticles[].
 *
 * Only the x and y coordinates take part.  Pairs at exactly zero distance
 * (overlapping particles, or a particle paired with itself) are skipped.
 *
 * Returns the largest value, over the local particles, of
 * |force from this block| / (smallest squared distance seen), which the
 * caller uses as a rough estimate of (1/m)|df/dx|.  The smallest squared
 * distance starts at 100.0 for every particle.
 */
double ComputeForces(Particle myparticles[], Particle others[], ParticleV pv[],
                     int npart) {
  double largest = 0.0;

  for (int a = 0; a < npart; ++a) {
    const double xa = myparticles[a].x;
    const double ya = myparticles[a].y;
    double nearest2 = 100.0; /* smallest squared distance so far */
    double sum_fx = 0.0;
    double sum_fy = 0.0;

    for (int b = 0; b < npart; ++b) {
      const double dx = xa - others[b].x;
      const double dy = ya - others[b].y;
      const double dist2 = dx * dx + dy * dy;

      /* ignore overlap and same particle */
      if (dist2 != 0.0) {
        if (dist2 < nearest2)
          nearest2 = dist2;

        /* inverse-square attraction: m * d / |d|^3 */
        const double dist3 = dist2 * sqrt(dist2);
        const double m_other = others[b].mass;
        sum_fx -= m_other * dx / dist3;
        sum_fy -= m_other * dy / dist3;
      }
    }

    pv[a].fx += sum_fx;
    pv[a].fy += sum_fy;

    /* Rough estimate of (1/m)|df / dx| for this particle */
    const double estimate = sqrt(sum_fx * sum_fx + sum_fy * sum_fy) / nearest2;
    if (estimate > largest)
      largest = estimate;
  }
  return largest;
}

/*
 * Advance the first npart particles by one variable-step leapfrog step and
 * choose the step size for the next call.
 *
 * particles  positions to update in place (x and y only; z is not touched)
 * pv         previous positions and accumulated forces; on return xold/yold
 *            hold the positions from before this step and fx/fy are reset
 * npart      number of local particles
 * max_f      local stability estimate returned by ComputeForces
 * commring   communicator over which the new step size is agreed (this is a
 *            collective call: every process in commring must make it)
 *
 * Returns the size of the time step that was just taken.
 *
 * The current and previous step sizes are kept in function-local statics, so
 * the function carries state from one call to the next.
 */
double ComputeNewPos(Particle particles[], ParticleV pv[], int npart,
                     double max_f, MPI_Comm commring) {
  static double dt_old = 0.001, dt = 0.001;
  const double dt_used = dt; /* the step this call advances by */
  double a0, a1, a2;
  double dt_est, dt_new;
  int i;

  /* integration is a0 * x^+ + a1 * x + a2 * x^- = f / m, where x^+ lies dt
     ahead of x and x^- lies dt_old behind it */
  a0 = 2.0 / (dt * (dt + dt_old));
  a2 = 2.0 / (dt_old * (dt + dt_old));
  a1 = -(a0 + a2); /* also -2/(dt*dt_old) */

  for (i = 0; i < npart; i++) {
    double xi, yi;
    /* Very, very simple leapfrog time integration.  We use a variable
       step version to simplify time-step control. */
    xi = particles[i].x;
    yi = particles[i].y;
    particles[i].x = (pv[i].fx - a1 * xi - a2 * pv[i].xold) / a0;
    particles[i].y = (pv[i].fy - a1 * yi - a2 * pv[i].yold) / a0;
    pv[i].xold = xi;
    pv[i].yold = yi;
    pv[i].fx = 0;
    pv[i].fy = 0;
  }

  /* Recompute a time step. Stability criterion is roughly
     2/sqrt(1/m |df/dx|) >= dt.  We leave a little room.  With no force
     estimate at all there is no local restriction, so avoid dividing by
     zero and let the other processes decide. */
  dt_est = (max_f > 0.0) ? 1.0 / sqrt(max_f) : HUGE_VAL;
  /* Set a minimum: */
  if (dt_est < 1.0e-6)
    dt_est = 1.0e-6;
  MPI_Allreduce(&dt_est, &dt_new, 1, MPI_DOUBLE, MPI_MIN, commring);

  /* The step just taken is the "previous" step of the next call, whether or
     not the step size changes.  (Refreshing dt_old only when dt changed left
     a stale value behind and made the next step's coefficients wrong.) */
  dt_old = dt_used;

  /* Modify time step: shrink straight to the limit, grow by at most 2x */
  if (dt_new < dt) {
    dt = dt_new;
  } else if (dt_new > 4.0 * dt) {
    dt *= 2.0;
  }

  return dt_used;
}